# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import app
import tensorflow as tf
import seq2seq_model

from tensorflow.python.tools import freeze_graph

# Fixed job/device identifiers consumed by the Ascend (NPU) runtime.
os.environ["JOB_ID"] = "10086"
os.environ["ASCEND_DEVICE_ID"] = "0"

## Required parameters
# Command-line flags registered on the global tf.app.flags registry. These
# mirror the training-time hyper-parameters: the graph rebuilt here must match
# the architecture of the checkpoint being frozen (size, num_layers, vocab
# sizes, cell type), or saver.restore will fail with a shape mismatch.
tf.app.flags.DEFINE_float("learning_rate", 0.7, "Learning rate.")
tf.app.flags.DEFINE_float("learning_rate_decay_factor", 0.5,
                          "Learning rate decays by this much.")
tf.app.flags.DEFINE_float("max_gradient_norm", 5.0,
                          "Clip gradients to this norm.")
tf.app.flags.DEFINE_integer("batch_size", 128,
                            "Batch size to use during training.")
tf.app.flags.DEFINE_integer("size", 1000, "Size of each model layer.")
tf.app.flags.DEFINE_integer("num_layers", 4, "Number of layers in the model.")
tf.app.flags.DEFINE_integer("from_vocab_size", 160000, "English vocabulary size.")
tf.app.flags.DEFINE_integer("to_vocab_size", 80000, "French vocabulary size.")
tf.app.flags.DEFINE_string("train_dir", "./model", "Training directory.")
tf.app.flags.DEFINE_string("output_dir", "./pb_model", "PB model directory.")
tf.app.flags.DEFINE_boolean("use_fp16", False,
                            "Train using fp16 instead of fp32.")
tf.app.flags.DEFINE_boolean("use_lstm", True,
                            "If true, we use LSTM cells instead of GRU cells.")

# Parsed flag values; populated when app.run() processes argv.
FLAGS = tf.app.flags.FLAGS

# Bucket boundaries passed to Seq2SeqModel — presumably (source length,
# target length) limit pairs, as in the standard TF seq2seq bucketing
# scheme. TODO(review): confirm against seq2seq_model.Seq2SeqModel.
_buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]


def create_model(session, forward_only):
    """Build the Seq2Seq translation model and load weights if any exist.

    Args:
        session: active tf.Session used to restore or initialize variables.
        forward_only: if True, the model is constructed for inference only.

    Returns:
        A (model, ckpt) tuple. `ckpt` is the CheckpointState found under
        FLAGS.train_dir, or None when no checkpoint directory/state exists.
    """
    model = seq2seq_model.Seq2SeqModel(
        FLAGS.from_vocab_size,
        FLAGS.to_vocab_size,
        _buckets,
        FLAGS.size,
        FLAGS.num_layers,
        FLAGS.max_gradient_norm,
        FLAGS.batch_size,
        FLAGS.learning_rate,
        FLAGS.learning_rate_decay_factor,
        use_lstm=FLAGS.use_lstm,
        forward_only=forward_only,
        # Inference can optionally run in half precision.
        dtype=tf.float16 if FLAGS.use_fp16 else tf.float32)
    ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
    restorable = ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path)
    if restorable:
        print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
        model.saver.restore(session, ckpt.model_checkpoint_path)
    else:
        # No usable checkpoint: start from randomly initialized variables.
        print("Created model with fresh parameters.")
        session.run(tf.global_variables_initializer())
    return model, ckpt


def main(_):
    """Export the trained seq2seq model as a frozen inference graph.

    Rebuilds the inference graph, writes its GraphDef to
    FLAGS.output_dir/model.pb, then folds the checkpoint variables into it,
    producing FLAGS.output_dir/seq2seq.pb.

    Raises:
        ValueError: if no checkpoint is found under FLAGS.train_dir.
    """
    with tf.Session() as sess:
        # The model object itself is not needed here — only the side effect
        # of building the graph into `sess` and locating the checkpoint.
        _, ckpt = create_model(sess, True)
        if ckpt is None or not ckpt.model_checkpoint_path:
            # Fail with an actionable message instead of the AttributeError
            # that dereferencing a None ckpt below would raise.
            raise ValueError(
                "No checkpoint found in %s; train the model first."
                % FLAGS.train_dir)
        # Serialize the GraphDef so freeze_graph can consume it.
        tf.train.write_graph(sess.graph_def, FLAGS.output_dir, 'model.pb')
        # The 100 output nodes are named add, add_1, ..., add_99 — presumably
        # the per-step decoder output ops in graph-construction order.
        # TODO(review): confirm these node names against the built graph.
        output_nodes = ",".join(
            ["add"] + ["add_%d" % i for i in range(1, 100)])
        freeze_graph.freeze_graph(
            input_graph=os.path.join(FLAGS.output_dir, "model.pb"),
            input_saver='',
            input_binary=False,  # model.pb was written as text, not binary
            input_checkpoint=ckpt.model_checkpoint_path,
            output_node_names=output_nodes,
            restore_op_name='save/restore_all',
            filename_tensor_name='save/Const:0',
            # Name of the frozen inference graph to generate.
            output_graph=os.path.join(FLAGS.output_dir, "seq2seq.pb"),
            clear_devices=False,
            initializer_nodes='')


if __name__ == "__main__":
    # absl's app.run parses command-line flags, then invokes main(argv).
    app.run(main)
